# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""Script for updating tensorflow/tools/compatibility/reorders_v2.py.

To update reorders_v2.py, run:
  bazel run tensorflow/tools/compatibility/update:generate_v2_reorders_map
"""
# pylint: enable=line-too-long

from absl import app
import tensorflow as tf

from tensorflow import python as tf_python  # pylint: disable=unused-import
from tensorflow.python.lib.io import file_io
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
from tensorflow.tools.compatibility import tf_upgrade_v2

# The `from tensorflow import python as tf_python` import above is needed so
# that TensorFlow python modules are in sys.modules.


_OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/reorders_v2.py'
_FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
\"\"\"List of renames to apply when converting from TF 1.0 to TF 2.0.

THIS FILE IS AUTOGENERATED: To update, please run:
  bazel run tensorflow/tools/compatibility/update:generate_v2_reorders_map
This file should be updated whenever a function is added to
self.reordered_function_names in tf_upgrade_v2.py.
\"\"\"
"""


def _mask_common_prefix(args_v1, args_v2):
  """Replaces the leading argument names shared by v1 and v2 with `None`.

  Args:
    args_v1: Argument names of the v1 symbol.
    args_v2: Argument names of the v2 symbol.

  Returns:
    A list with one entry per element of `args_v1`. Entries inside the longest
    common prefix of `args_v1` and `args_v2` are `None`; every later entry is
    the corresponding v1 argument name.
  """
  prefix_len = 0
  for arg_v1, arg_v2 in zip(args_v1, args_v2):
    if arg_v1 != arg_v2:
      break
    prefix_len += 1
  return [None] * prefix_len + list(args_v1[prefix_len:])


def collect_function_arg_names(function_names, return_all_args_function_names,
                               function_renames):
  """Determines argument names for reordered function signatures.

  Every v1 alias of a symbol matched by `function_names` is included in the
  result, not only the names listed in `function_names`.

  Args:
    function_names: Functions to collect arguments for.
    return_all_args_function_names: Functions to collect all argument names for.
    function_renames: Function renames between v1 and v2.

  Returns:
    Dictionary mapping function names to a list of argument names. Each argument
    name list can have leading `None` elements to indicate that some of the
    function arguments did not change between v1 and v2.

  Raises:
    ValueError: If a name in `function_names` is not exported in
      `tf.compat.v1`; if a reordered symbol is merely renamed to a
      `tf.compat.v1.` symbol; if its v2 counterpart is not found in
      `tf.compat.v2`; or if its v1 and v2 argument lists are identical.
  """
  function_name_v1_to_attr = {}
  function_name_v2_to_attr = {}

  def visit(unused_path, unused_parent, children):
    """Visitor that collects arguments for reordered functions."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])

      api_names_v1 = ['tf.' + name for name in tf_export.get_v1_names(attr)]
      if any(name in function_names for name in api_names_v1):
        for name in api_names_v1:
          function_name_v1_to_attr[name] = attr

      api_names_v2 = ['tf.' + name for name in tf_export.get_v2_names(attr)]
      for name in api_names_v2:
        function_name_v2_to_attr[name] = attr

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.private_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf.compat.v1, visitor)
  traverse.traverse(tf.compat.v2, visitor)

  def get_arguments_list(attr):
    """Returns positional argument names of `attr` (constructor for classes)."""
    if tf_inspect.isclass(attr):
      # Get constructor arguments if attr is a class; skip the 'self' argument.
      return tf_inspect.getargspec(attr.__init__)[0][1:]
    # getargspec returns a tuple of (args, varargs, keywords, defaults); we
    # only need args.
    return tf_inspect.getargspec(attr)[0]

  # Sorted so the error message is deterministic regardless of set ordering.
  missing_v1 = sorted(
      {name for name in function_names if name not in function_name_v1_to_attr})
  if missing_v1:
    raise ValueError(
        'Symbols not found in `tf.compat.v1`: '
        f'`{"`, `".join(missing_v1)}`'
    )

  # Map from reordered function name to its arguments.
  function_to_args = {}

  for name_v1, attr_v1 in function_name_v1_to_attr.items():
    args_v1 = get_arguments_list(attr_v1)

    # Per `return_all_args_function_names override`, return all argument names
    # without comparing with v2.
    if name_v1 in return_all_args_function_names:
      function_to_args[name_v1] = args_v1
      continue

    name_v2 = function_renames.get(name_v1, name_v1)
    # If the rename is simply mapping `tf.x` to `tf.compat.v1.x`, there is no
    # change in the arguments, we shouldn't have it in the list.
    if name_v1 in function_renames and name_v2.startswith('tf.compat.v1.'):
      raise ValueError(f'Symbol `{name_v1}` is renamed to `{name_v2}`, '
                       'no need to add keyword argument names, '
                       'remove from `reordered_function_names`')

    if name_v2 not in function_name_v2_to_attr:
      raise ValueError(f'Symbol `{name_v2}` not found in `tf.compat.v2`')
    args_v2 = get_arguments_list(function_name_v2_to_attr[name_v2])

    # If there is no change in the arguments, we shouldn't have it in the list.
    if args_v1 == args_v2:
      raise ValueError(f'Symbol `{name_v1}` has no changes in arguments, '
                       'no need to add keyword argument names, '
                       'remove from `reordered_function_names`')

    # Leading arguments identical in v1 and v2 are recorded as `None`.
    function_to_args[name_v1] = _mask_common_prefix(args_v1, args_v2)

  return function_to_args


def get_reorder_line(name, arg_list):
  """Renders one entry of the generated `reorders` dict as a source line.

  Args:
    name: Symbol name used as the (single-quoted) dict key.
    arg_list: Value for the entry; rendered with `str()`.

  Returns:
    A string of the form `    '<name>': <str(arg_list)>` (4-space indent).
  """
  return f"    '{name!s}': {arg_list!s}"


def update_reorders_v2(output_file_path):
  """Generates the `reorders` module text and writes it to a file.

  The text is `_FILE_HEADER` followed by a `reorders = {...}` dict literal whose
  entry lines are sorted lexicographically.

  Args:
    output_file_path: Destination path; any existing contents are replaced.
  """
  spec = tf_upgrade_v2.TFAPIChangeSpec()
  # Symbols with a `function_transformers` entry are assumed to be rewritten via
  # keyword arguments, so all of their argument names are emitted.
  all_reorders = collect_function_arg_names(
      spec.reordered_function_names,
      spec.function_transformers.keys(),
      spec.symbol_renames)

  # Each line looks like:  'tf.function_name': ['arg1', 'arg2', ...]
  reorder_lines = sorted(
      get_reorder_line(name, arg_names)
      for name, arg_names in all_reorders.items())
  file_text = ''.join(
      [_FILE_HEADER, 'reorders = {\n', ',\n'.join(reorder_lines), '\n}\n'])
  file_io.write_string_to_file(output_file_path, file_text)


def main(unused_argv):
  """Regenerates the reorders module at `_OUTPUT_FILE_PATH`."""
  del unused_argv  # Passed in by `app.run`; not needed here.
  update_reorders_v2(_OUTPUT_FILE_PATH)


if __name__ == '__main__':
  app.run(main)
